{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear text segmentation\n",
    "\n",
    "<!-- {{ add_binder_block(page) }} -->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Linear text segmentation consists of dividing a text into several meaningful segments.\n",
    "It can be seen as a change point detection task and can therefore be carried out with `ruptures`.\n",
    "This example does exactly that on a well-known data set introduced in [[Choi2000](#Choi2000)]."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "First we import packages and define a few utility functions.\n",
    "This section can be skipped at first reading.\n",
    "\n",
    "**Library imports.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import nltk\n",
    "import numpy as np\n",
    "import ruptures as rpt  # our package\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.stem import PorterStemmer\n",
    "from nltk.tokenize import regexp_tokenize\n",
    "from ruptures.base import BaseCost\n",
    "from ruptures.exceptions import NotEnoughPoints\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "from matplotlib.colors import LogNorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nltk.download(\"stopwords\")\n",
    "STOPWORD_SET = set(\n",
    "    stopwords.words(\"english\")\n",
    ")  # set of stopwords of the English language\n",
    "PUNCTUATION_SET = set(\"!\\\"#$%&'()*+,-./:;<=>?@[\\\\]^_`{|}~\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Utility functions.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(list_of_sentences: list) -> list:\n",
    "    \"\"\"Preprocess each sentence (remove punctuation and stopwords, then stem).\"\"\"\n",
    "    ps = PorterStemmer()  # a single stemmer is reused for all sentences\n",
    "    transformed = list()\n",
    "    for sentence in list_of_sentences:\n",
    "        # keep alphanumeric tokens only, which drops the punctuation\n",
    "        list_of_words = regexp_tokenize(text=sentence.lower(), pattern=r\"\\w+\")\n",
    "        list_of_words = [\n",
    "            ps.stem(word) for word in list_of_words if word not in STOPWORD_SET\n",
    "        ]\n",
    "        transformed.append(\" \".join(list_of_words))\n",
    "    return transformed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_square_on_ax(start, end, ax, linewidth=0.8):\n",
    "    \"\"\"Draw a square on the given ax object.\"\"\"\n",
    "    ax.vlines(\n",
    "        x=[start - 0.5, end - 0.5],\n",
    "        ymin=start - 0.5,\n",
    "        ymax=end - 0.5,\n",
    "        linewidth=linewidth,\n",
    "    )\n",
    "    ax.hlines(\n",
    "        y=[start - 0.5, end - 0.5],\n",
    "        xmin=start - 0.5,\n",
    "        xmax=end - 0.5,\n",
    "        linewidth=linewidth,\n",
    "    )\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Description\n",
    "\n",
    "The text to segment is a concatenation of excerpts from ten different documents randomly selected from the so-called Brown corpus (described [here](http://icame.uib.no/brown/bcm.html)).\n",
    "Each excerpt has nine to eleven sentences, amounting to 99 sentences in total.\n",
    "The complete text is shown in [Appendix A](#appendix-a).\n",
    "\n",
    "These data stem from a larger data set which is thoroughly described in [[Choi2000](#Choi2000)] and can be downloaded [here](https://web.archive.org/web/20030206011734/http://www.cs.man.ac.uk/~mary/choif/software/C99-1.2-release.tgz).\n",
    "This is a common benchmark to evaluate text segmentation methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading the text\n",
    "filepath = Path(\"../data/text-segmentation-data.txt\")\n",
    "original_text = filepath.read_text().split(\"\\n\")\n",
    "TRUE_BKPS = [11, 20, 30, 40, 49, 59, 69, 80, 90, 99]  # read from the data description\n",
    "\n",
    "print(f\"There are {len(original_text)} sentences, from {len(TRUE_BKPS)} documents.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The objective is to automatically recover the boundaries of the 10 excerpts, using the fact that they come from quite different documents and therefore have distinct topics.\n",
    "\n",
    "For instance, in the small extract of text printed in the following cell, an accurate text segmentation procedure would be able to detect that the first two sentences (10 and 11) and the last three sentences (12 to 14) belong to two different documents and have very different semantic fields."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print 5 sentences from the original text\n",
    "start, end = 9, 14\n",
    "for line_number, sentence in enumerate(original_text[start:end], start=start + 1):\n",
    "    sentence = sentence.strip(\"\\n\")\n",
    "    print(f\"{line_number:>2}: {sentence}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing\n",
    "\n",
    "Before performing text segmentation, the original text is preprocessed.\n",
    "In a nutshell (see [[Choi2000](#Choi2000)] for more details),\n",
    "\n",
    "- the punctuation and stopwords are removed;\n",
    "- words are reduced to their stems (e.g., \"waited\" and \"waiting\" become \"wait\");\n",
    "- a vector of word counts is computed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transform text\n",
    "transformed_text = preprocess(original_text)\n",
    "# print original and transformed\n",
    "ind = 97\n",
    "print(\"Original sentence:\")\n",
    "print(f\"\\t{original_text[ind]}\")\n",
    "print()\n",
    "print(\"Transformed:\")\n",
    "print(f\"\\t{transformed_text[ind]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Once the text is preprocessed, each sentence is transformed into a vector of word counts.\n",
    "vectorizer = CountVectorizer(analyzer=\"word\")\n",
    "vectorized_text = vectorizer.fit_transform(transformed_text)\n",
    "\n",
    "msg = f\"There are {len(vectorizer.get_feature_names_out())} different words in the corpus, e.g. {vectorizer.get_feature_names_out()[20:30]}.\"\n",
    "print(msg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the vectorized text representation is a (very) sparse matrix."
   ]
  },
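  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the word-count representation concrete, here is a minimal plain-Python sketch of what a bag-of-words vectorizer computes.\n",
    "The sentences in `toy_sentences` and the helper `count_vector` are hypothetical and chosen for illustration only; this is not how `CountVectorizer` is implemented, merely a toy equivalent on whitespace-separated tokens.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# Toy, already-preprocessed sentences (hypothetical, not from the data set).\n",
    "toy_sentences = ['wait train station', 'train late', 'cook pasta sauce']\n",
    "\n",
    "# Vocabulary: sorted list of the distinct words.\n",
    "vocabulary = sorted({word for s in toy_sentences for word in s.split()})\n",
    "\n",
    "\n",
    "def count_vector(sentence, vocabulary):\n",
    "    # Word counts of a sentence, ordered as the vocabulary.\n",
    "    counts = Counter(sentence.split())\n",
    "    return [counts[word] for word in vocabulary]\n",
    "\n",
    "\n",
    "toy_matrix = [count_vector(s, vocabulary) for s in toy_sentences]\n",
    "n_zeros = sum(row.count(0) for row in toy_matrix)\n",
    "sparsity = n_zeros / (len(toy_matrix) * len(vocabulary))\n",
    "\n",
    "print(vocabulary)\n",
    "print(toy_matrix)\n",
    "print(f'Fraction of zero entries: {sparsity:.2f}')\n",
    "```\n",
    "\n",
    "Even on this toy corpus, 13 of the 21 entries are zero, which is why a sparse matrix format is a natural choice for larger vocabularies."
   ]
  },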
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cost function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare (the vectorized representation of) two sentences, [[Choi2000]](#Choi2000) uses the cosine similarity $k_{\\text{cosine}}: \\mathbb{R}^d \\times \\mathbb{R}^d \\rightarrow \\mathbb{R}$:\n",
    "\n",
    "$$ k_{\\text{cosine}}(x, y) := \\frac{\\langle x \\mid y \\rangle}{\\|x\\|\\|y\\|} $$\n",
    "\n",
    "where $x$ and $y$ are two $d$-dimensional vectors of word counts.\n",
    "\n",
    "Text segmentation now amounts to a kernel change point detection problem (see [[Truong2020]](#Truong2020) for more details).\n",
    "A cosine kernel is implemented in `ruptures`, but the current implementation does not exploit the sparse structure of the vectorized text representation and can therefore be slow.\n",
    "For this reason, we create a [custom cost function](../../user-guide/costs/costcustom) instead.\n",
    "\n",
    "Let $y=\\{y_0, y_1,\\dots,y_{T-1}\\}$ be a $d$-dimensional signal with $T$ samples.\n",
    "Recall that a cost function $c(\\cdot)$ that derives from a kernel $k(\\cdot, \\cdot)$ is such that\n",
    "\n",
    "$$\n",
    "c(y_{a..b}) = \\sum_{t=a}^{b-1} G_{t, t} - \\frac{1}{b-a} \\sum_{a \\leq s < b } \\sum_{a \\leq t < b} G_{s,t}\n",
    "$$\n",
    "\n",
    "where $y_{a..b}$ is the subsignal $\\{y_a, y_{a+1},\\dots,y_{b-1}\\}$ and $G_{s,t}:=k(y_s, y_t)$ (see [[Truong2020]](#Truong2020) for more details).\n",
    "In other words, $(G_{s,t})_{s,t}$ is the $T\\times T$ Gram matrix of $y$.\n",
    "Thanks to this formula, we can now implement our custom cost function (named `CosineCost` in the following cell)."
   ]
  },
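  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cost formula can be checked by hand on a tiny example.\n",
    "The following sketch, written in plain Python and independent of `ruptures`, builds the cosine Gram matrix of a hypothetical four-sample signal (`toy_signal`) and evaluates the cost on two segments.\n",
    "The helper names `cosine_kernel` and `kernel_cost` are illustrative and assume non-zero vectors.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine_kernel(x, y):\n",
    "    # Cosine similarity between two non-zero vectors.\n",
    "    dot = sum(a * b for a, b in zip(x, y))\n",
    "    norm_x = math.sqrt(sum(a * a for a in x))\n",
    "    norm_y = math.sqrt(sum(b * b for b in y))\n",
    "    return dot / (norm_x * norm_y)\n",
    "\n",
    "\n",
    "def kernel_cost(gram, a, b):\n",
    "    # Cost of the segment y[a:b]: diagonal sum minus the block mean times (b - a).\n",
    "    diag_sum = sum(gram[t][t] for t in range(a, b))\n",
    "    block_sum = sum(gram[s][t] for s in range(a, b) for t in range(a, b))\n",
    "    return diag_sum - block_sum / (b - a)\n",
    "\n",
    "\n",
    "# Toy signal: two samples along one direction, then two along another.\n",
    "toy_signal = [[1, 0], [2, 0], [0, 1], [0, 3]]\n",
    "toy_gram = [[cosine_kernel(x, y) for y in toy_signal] for x in toy_signal]\n",
    "\n",
    "print(kernel_cost(toy_gram, 0, 2))  # homogeneous segment\n",
    "print(kernel_cost(toy_gram, 0, 4))  # segment straddling the change\n",
    "```\n",
    "\n",
    "The homogeneous segment has zero cost, while the segment that straddles the change has a cost of 2: the cost is low when all samples of a segment point in similar directions."
   ]
  },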
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CosineCost(BaseCost):\n",
    "    \"\"\"Cost derived from the cosine similarity.\"\"\"\n",
    "\n",
    "    # The 2 following attributes must be specified for compatibility.\n",
    "    model = \"custom_cosine\"\n",
    "    min_size = 2\n",
    "\n",
    "    def fit(self, signal):\n",
    "        \"\"\"Set the internal parameter.\"\"\"\n",
    "        self.signal = signal\n",
    "        self.gram = cosine_similarity(signal, dense_output=False)\n",
    "        return self\n",
    "\n",
    "    def error(self, start, end) -> float:\n",
    "        \"\"\"Return the approximation cost on the segment [start:end].\n",
    "\n",
    "        Args:\n",
    "            start (int): start of the segment\n",
    "            end (int): end of the segment\n",
    "        Returns:\n",
    "            segment cost\n",
    "        Raises:\n",
    "            NotEnoughPoints: when the segment is too short (less than `min_size` samples).\n",
    "        \"\"\"\n",
    "        if end - start < self.min_size:\n",
    "            raise NotEnoughPoints\n",
    "        sub_gram = self.gram[start:end, start:end]\n",
    "        val = sub_gram.diagonal().sum()\n",
    "        val -= sub_gram.sum() / (end - start)\n",
    "        return val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute change points\n",
    "\n",
    "If the number $K$ of change points is assumed to be known, we can use [dynamic programming](../../user-guide/detection/dynp) to search for the exact segmentation $\\hat{t}_1,\\dots,\\hat{t}_K$ that minimizes the sum of segment costs:\n",
    "\n",
    "$$\n",
    "\\hat{t}_1,\\dots,\\hat{t}_K := \\text{arg}\\min_{t_1,\\dots,t_K} \\left[ c(y_{0..t_1}) + c(y_{t_1..t_2}) + \\dots + c(y_{t_K..T}) \\right].\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_bkps = 9  # there are 9 change points (10 text segments)\n",
    "\n",
    "algo = rpt.Dynp(custom_cost=CosineCost(), min_size=2, jump=1).fit(vectorized_text)\n",
    "predicted_bkps = algo.predict(n_bkps=n_bkps)\n",
    "\n",
    "print(f\"True change points are\\t\\t{TRUE_BKPS}.\")\n",
    "print(f\"Detected change points are\\t{predicted_bkps}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Note that the last change point index is simply the length of the signal. This is by design.)\n",
    "\n",
    "Predicted breakpoints are quite close to the true change points.\n",
    "Indeed, most estimated changes are less than one sentence away from a true change.\n",
    "The last change is less accurately predicted with an error of 4 sentences.\n",
    "To overcome this issue, one solution would be to consider a richer representation (compared to the sparse word frequency vectors)."
   ]
  },
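  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers curious about the search itself, the dynamic programming recursion can be sketched in a few lines of plain Python.\n",
    "This is only an illustration under simplifying assumptions, not the `ruptures` implementation: the function names (`segment_cost`, `dynp_sketch`) and the toy signal are hypothetical, the cost is a sum of squared deviations from the segment mean rather than the cosine cost, and no minimum segment size is enforced.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def segment_cost(values, a, b):\n",
    "    # Toy cost: sum of squared deviations from the mean on values[a:b].\n",
    "    segment = values[a:b]\n",
    "    mean = sum(segment) / len(segment)\n",
    "    return sum((v - mean) ** 2 for v in segment)\n",
    "\n",
    "\n",
    "def dynp_sketch(values, n_bkps, cost=segment_cost):\n",
    "    # Return the change points minimizing the sum of segment costs.\n",
    "    n = len(values)\n",
    "    # best[k][t]: minimal cost of splitting values[0:t] into k + 1 segments.\n",
    "    best = [[math.inf] * (n + 1) for _ in range(n_bkps + 1)]\n",
    "    arg = [[0] * (n + 1) for _ in range(n_bkps + 1)]\n",
    "    for t in range(1, n + 1):\n",
    "        best[0][t] = cost(values, 0, t)\n",
    "    for k in range(1, n_bkps + 1):\n",
    "        for t in range(k + 1, n + 1):\n",
    "            for s in range(k, t):  # s is the candidate last change point\n",
    "                candidate = best[k - 1][s] + cost(values, s, t)\n",
    "                if candidate < best[k][t]:\n",
    "                    best[k][t], arg[k][t] = candidate, s\n",
    "    # Backtrack from the end of the signal.\n",
    "    bkps, t = [n], n\n",
    "    for k in range(n_bkps, 0, -1):\n",
    "        t = arg[k][t]\n",
    "        bkps.append(t)\n",
    "    return sorted(bkps)\n",
    "\n",
    "\n",
    "toy_values = [0, 0, 0, 5, 5, 5, 5, 1, 1, 1]\n",
    "print(dynp_sketch(toy_values, n_bkps=2))  # [3, 7, 10]\n",
    "```\n",
    "\n",
    "As with `ruptures`, the last returned index is the length of the signal."
   ]
  },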
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize segmentations\n",
    "\n",
    "**Show sentence numbers.**\n",
    "\n",
    "In the following cell, the two segmentations (true and predicted) can be visually compared.\n",
    "For each paragraph, the sentence numbers are shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_segment_list = rpt.utils.pairwise([0] + TRUE_BKPS)\n",
    "predicted_segment_list = rpt.utils.pairwise([0] + predicted_bkps)\n",
    "\n",
    "for n_paragraph, (true_segment, predicted_segment) in enumerate(\n",
    "    zip(true_segment_list, predicted_segment_list), start=1\n",
    "):\n",
    "    print(f\"Paragraph n°{n_paragraph:02d}\")\n",
    "    start_true, end_true = true_segment\n",
    "    start_pred, end_pred = predicted_segment\n",
    "\n",
    "    start = min(start_true, start_pred)\n",
    "    end = max(end_true, end_pred)\n",
    "    msg = \" \".join(\n",
    "        f\"{ind+1:02d}\" if (start_true <= ind < end_true) else \"  \"\n",
    "        for ind in range(start, end)\n",
    "    )\n",
    "    print(f\"(true)\\t{msg}\")\n",
    "    msg = \" \".join(\n",
    "        f\"{ind+1:02d}\" if (start_pred <= ind < end_pred) else \"  \"\n",
    "        for ind in range(start, end)\n",
    "    )\n",
    "    print(f\"(pred)\\t{msg}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Show the Gram matrix.**\n",
    "\n",
    "In addition, the text segmentation can be shown on the Gram matrix that was used to detect changes.\n",
    "This is done in the following cell.\n",
    "\n",
    "Most segments (represented by the blue squares) are similar between the true segmentation and the predicted segmentation, except for the last two.\n",
    "This is mainly due to the fact that, in the penultimate excerpt, all sentences are dissimilar (with respect to the cosine measure)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax_arr = plt.subplots(nrows=1, ncols=2, figsize=(7, 5), dpi=200)\n",
    "\n",
    "# plot config\n",
    "title_fontsize = 10\n",
    "label_fontsize = 7\n",
    "title_list = [\"True text segmentation\", \"Predicted text segmentation\"]\n",
    "\n",
    "for ax, title, bkps in zip(ax_arr, title_list, [TRUE_BKPS, predicted_bkps]):\n",
    "    # plot gram matrix\n",
    "    ax.imshow(algo.cost.gram.toarray(), cmap=cm.plasma, norm=LogNorm())\n",
    "    # add text segmentation\n",
    "    for start, end in rpt.utils.pairwise([0] + bkps):\n",
    "        draw_square_on_ax(start=start, end=end, ax=ax)\n",
    "    # add labels and title\n",
    "    ax.set_title(title, fontsize=title_fontsize)\n",
    "    ax.set_xlabel(\"Sentence index\", fontsize=label_fontsize)\n",
    "    ax.set_ylabel(\"Sentence index\", fontsize=label_fontsize)\n",
    "    ax.tick_params(axis=\"both\", which=\"major\", labelsize=label_fontsize)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "This example shows how to apply `ruptures` to a text segmentation task.\n",
    "In detail, we detected shifts in the vocabulary of a collection of sentences using common NLP preprocessing and transformation.\n",
    "This task amounts to a kernel change point detection procedure where the kernel is the cosine kernel.\n",
    "\n",
    "Such results can then be used to characterize the structure of the text for subsequent NLP tasks.\n",
    "This procedure should certainly be enriched with more relevant and compact representations to better detect changes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Appendix A\n",
    "\n",
    "The complete text used in this notebook is as follows.\n",
    "Note that the line numbers and the blank lines (added to visually mark the boundaries between excerpts) are not part of the text fed to the segmentation method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for start, end in rpt.utils.pairwise([0] + TRUE_BKPS):\n",
    "    excerpt = original_text[start:end]\n",
    "    for n_line, sentence in enumerate(excerpt, start=start + 1):\n",
    "        sentence = sentence.strip(\"\\n\")\n",
    "        print(f\"{n_line:>2}: {sentence}\")\n",
    "    print()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Authors\n",
    "\n",
    "This example notebook has been authored by Olivier Boulant and edited by Charles Truong."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "<a id=\"Choi2000\">[Choi2000]</a>\n",
    "Choi, F. Y. Y. (2000). Advances in domain independent linear text segmentation. Proceedings of the North American Chapter of the Association for Computational Linguistics Conference (NAACL), 26–33.\n",
    "\n",
    "<a id=\"Truong2020\">[Truong2020]</a>\n",
    "Truong, C., Oudre, L., & Vayatis, N. (2020). Selective review of offline change point detection methods. Signal Processing, 167."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ruptures",
   "language": "python",
   "name": "ruptures"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
